/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Execution environment
*/

#include <cstring>

#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"

#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"

#include "cutlass/library/util.h"

#include "device_allocation.h"

namespace cutlass {
namespace profiler {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Returns the total size in bytes of `capacity` elements of `type`.
/// Computed in bits to support sub-byte numeric types.
/// NOTE(review): the division truncates when the total bit count is not a
/// multiple of 8 (e.g. 4-bit types with odd capacity) — confirm callers
/// always pass byte-aligned capacities.
size_t DeviceAllocation::bytes(library::NumericTypeID type, size_t capacity) {
    return size_t(cutlass::library::sizeof_bits(type)) * capacity / 8;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Computes the stride vector of the packed layout implied by `extent`.
/// Throws std::runtime_error when the extent rank does not match Layout.
template <typename Layout>
static std::vector<int> get_packed_layout_stride(
        std::vector<int> const& extent) {
    if (extent.size() != size_t(Layout::kRank)) {
        throw std::runtime_error(
                "Layout does not have same rank as extent vector.");
    }

    // Convert the extent vector into the layout's coordinate type.
    typename Layout::TensorCoord extent_coord;
    for (int idx = 0; idx < Layout::kRank; ++idx) {
        extent_coord[idx] = extent.at(idx);
    }

    // Ask the layout for its packed stride, then copy it into a std::vector.
    typename Layout::Stride packed_stride = Layout::packed(extent_coord).stride();

    std::vector<int> result(Layout::kStrideRank, 0);
    for (int idx = 0; idx < Layout::kStrideRank; ++idx) {
        result.at(idx) = packed_stride[idx];
    }

    return result;
}

/// Returns the stride of a packed layout for the given layout enumerant and
/// extent. Unrecognized layouts yield an empty stride vector.
std::vector<int> DeviceAllocation::get_packed_layout(
        library::LayoutTypeID layout_id, std::vector<int> const& extent) {
    // Dispatch on the runtime layout identifier to the concrete CUTLASS
    // layout type and return the computed packed stride directly.
    switch (layout_id) {
        case library::LayoutTypeID::kColumnMajor:
            return get_packed_layout_stride<cutlass::layout::ColumnMajor>(
                    extent);
        case library::LayoutTypeID::kRowMajor:
            return get_packed_layout_stride<cutlass::layout::RowMajor>(extent);
        case library::LayoutTypeID::kColumnMajorInterleavedK2:
            return get_packed_layout_stride<
                    cutlass::layout::ColumnMajorInterleaved<2>>(extent);
        case library::LayoutTypeID::kRowMajorInterleavedK2:
            return get_packed_layout_stride<
                    cutlass::layout::RowMajorInterleaved<2>>(extent);
        case library::LayoutTypeID::kColumnMajorInterleavedK4:
            return get_packed_layout_stride<
                    cutlass::layout::ColumnMajorInterleaved<4>>(extent);
        case library::LayoutTypeID::kRowMajorInterleavedK4:
            return get_packed_layout_stride<
                    cutlass::layout::RowMajorInterleaved<4>>(extent);
        case library::LayoutTypeID::kColumnMajorInterleavedK16:
            return get_packed_layout_stride<
                    cutlass::layout::ColumnMajorInterleaved<16>>(extent);
        case library::LayoutTypeID::kRowMajorInterleavedK16:
            return get_packed_layout_stride<
                    cutlass::layout::RowMajorInterleaved<16>>(extent);
        case library::LayoutTypeID::kColumnMajorInterleavedK32:
            return get_packed_layout_stride<
                    cutlass::layout::ColumnMajorInterleaved<32>>(extent);
        case library::LayoutTypeID::kRowMajorInterleavedK32:
            return get_packed_layout_stride<
                    cutlass::layout::RowMajorInterleaved<32>>(extent);
        case library::LayoutTypeID::kColumnMajorInterleavedK64:
            return get_packed_layout_stride<
                    cutlass::layout::ColumnMajorInterleaved<64>>(extent);
        case library::LayoutTypeID::kRowMajorInterleavedK64:
            return get_packed_layout_stride<
                    cutlass::layout::RowMajorInterleaved<64>>(extent);
        case library::LayoutTypeID::kTensorNCHW:
            return get_packed_layout_stride<cutlass::layout::TensorNCHW>(
                    extent);
        case library::LayoutTypeID::kTensorNHWC:
            return get_packed_layout_stride<cutlass::layout::TensorNHWC>(
                    extent);
        case library::LayoutTypeID::kTensorNDHWC:
            return get_packed_layout_stride<cutlass::layout::TensorNDHWC>(
                    extent);
        case library::LayoutTypeID::kTensorNC32HW32:
            return get_packed_layout_stride<cutlass::layout::TensorNCxHWx<32>>(
                    extent);
        case library::LayoutTypeID::kTensorNC64HW64:
            return get_packed_layout_stride<cutlass::layout::TensorNCxHWx<64>>(
                    extent);
        case library::LayoutTypeID::kTensorC32RSK32:
            return get_packed_layout_stride<cutlass::layout::TensorCxRSKx<32>>(
                    extent);
        case library::LayoutTypeID::kTensorC64RSK64:
            return get_packed_layout_stride<cutlass::layout::TensorCxRSKx<64>>(
                    extent);
        default:
            // Unknown layout: no stride information available.
            return {};
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Constructs a CUTLASS layout object of type Layout into `bytes` (when
/// non-null) and returns the capacity (in elements) implied by `extent`.
/// An empty `stride` requests the packed layout; in that case `stride` is
/// overwritten with the computed packed stride.
template <typename Layout>
static size_t construct_layout_(void* bytes, library::LayoutTypeID layout_id,
                                std::vector<int> const& extent,
                                std::vector<int>& stride) {
    if (extent.size() != Layout::kRank) {
        throw std::runtime_error(
                "Layout must have same rank as extent vector.");
    }

    if (Layout::kStrideRank) {
        if (stride.empty()) {
            // Empty stride denotes a packed layout; materialize it in place.
            stride = get_packed_layout_stride<Layout>(extent);
        } else if (stride.size() != Layout::kStrideRank) {
            throw std::runtime_error(
                    "Layout requires either empty stride or stride vector matching "
                    "Layout::kStrideRank");
        }
    }

    // Convert the vectors to the layout's native coordinate types.
    typename Layout::Stride stride_obj;
    for (int s = 0; s < Layout::kStrideRank; ++s) {
        stride_obj[s] = stride.at(s);
    }

    typename Layout::TensorCoord extent_obj;
    for (int e = 0; e < Layout::kRank; ++e) {
        extent_obj[e] = extent.at(e);
    }

    // Build the layout and, if requested, serialize it into the buffer.
    Layout layout(stride_obj);

    if (bytes) {
        *reinterpret_cast<Layout*>(bytes) = layout;
    }

    // Capacity (in elements) needed to hold the full extent.
    return layout.capacity(extent_obj);
}

/// Constructs a layout object of the concrete CUTLASS type corresponding to
/// `layout_id` into `bytes` (when non-null) and returns the capacity needed
/// (in elements). Unrecognized layouts return 0.
size_t DeviceAllocation::construct_layout(void* bytes,
                                          library::LayoutTypeID layout_id,
                                          std::vector<int> const& extent,
                                          std::vector<int>& stride) {
    switch (layout_id) {
        case library::LayoutTypeID::kColumnMajor:
            return construct_layout_<cutlass::layout::ColumnMajor>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kRowMajor:
            return construct_layout_<cutlass::layout::RowMajor>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kColumnMajorInterleavedK2:
            return construct_layout_<
                    cutlass::layout::ColumnMajorInterleaved<2>>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kRowMajorInterleavedK2:
            return construct_layout_<cutlass::layout::RowMajorInterleaved<2>>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kColumnMajorInterleavedK4:
            return construct_layout_<
                    cutlass::layout::ColumnMajorInterleaved<4>>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kRowMajorInterleavedK4:
            return construct_layout_<cutlass::layout::RowMajorInterleaved<4>>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kColumnMajorInterleavedK16:
            return construct_layout_<
                    cutlass::layout::ColumnMajorInterleaved<16>>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kRowMajorInterleavedK16:
            return construct_layout_<cutlass::layout::RowMajorInterleaved<16>>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kColumnMajorInterleavedK32:
            return construct_layout_<
                    cutlass::layout::ColumnMajorInterleaved<32>>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kRowMajorInterleavedK32:
            return construct_layout_<cutlass::layout::RowMajorInterleaved<32>>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kColumnMajorInterleavedK64:
            return construct_layout_<
                    cutlass::layout::ColumnMajorInterleaved<64>>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kRowMajorInterleavedK64:
            return construct_layout_<cutlass::layout::RowMajorInterleaved<64>>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kTensorNCHW:
            // Fixed: this case previously instantiated TensorNHWC, which
            // constructed the wrong layout type for NCHW tensors.
            return construct_layout_<cutlass::layout::TensorNCHW>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kTensorNHWC:
            return construct_layout_<cutlass::layout::TensorNHWC>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kTensorNDHWC:
            return construct_layout_<cutlass::layout::TensorNDHWC>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kTensorNC32HW32:
            return construct_layout_<cutlass::layout::TensorNCxHWx<32>>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kTensorNC64HW64:
            return construct_layout_<cutlass::layout::TensorNCxHWx<64>>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kTensorC32RSK32:
            return construct_layout_<cutlass::layout::TensorCxRSKx<32>>(
                    bytes, layout_id, extent, stride);

        case library::LayoutTypeID::kTensorC64RSK64:
            return construct_layout_<cutlass::layout::TensorCxRSKx<64>>(
                    bytes, layout_id, extent, stride);

        default:
            break;
    }

    // Unknown layout: nothing constructed, zero capacity.
    return 0;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Constructs an empty, invalid allocation. One of the reset() overloads (or
/// the allocating constructors) performs actual device allocation.
DeviceAllocation::DeviceAllocation()
        : type_(library::NumericTypeID::kInvalid),
          batch_stride_(0),
          capacity_(0),
          pointer_(nullptr),
          layout_(library::LayoutTypeID::kUnknown),
          batch_count_(1) {}

/// Allocates device memory for `capacity` elements of `type` as a single
/// batch. On allocation failure the object is left fully invalid and
/// std::bad_alloc is thrown.
DeviceAllocation::DeviceAllocation(library::NumericTypeID type, size_t capacity)
        : type_(type),
          batch_stride_(capacity),
          capacity_(capacity),
          pointer_(nullptr),
          layout_(library::LayoutTypeID::kUnknown),
          batch_count_(1) {
    cudaError_t result = cudaMalloc((void**)&pointer_, bytes(type, capacity));

    if (result != cudaSuccess) {
        // Leave the object fully invalid. Previously batch_stride_ kept its
        // stale value here, so good() and batch_stride() disagreed after a
        // failed allocation.
        type_ = library::NumericTypeID::kInvalid;
        batch_stride_ = 0;
        capacity_ = 0;
        pointer_ = nullptr;
        throw std::bad_alloc();
    }
}

/// Allocates device memory for a tensor described by a layout, extent,
/// optional stride (empty selects the packed stride), and batch count.
/// Delegates all allocation work to reset().
DeviceAllocation::DeviceAllocation(library::NumericTypeID type,
                                   library::LayoutTypeID layout_id,
                                   std::vector<int> const& extent,
                                   std::vector<int> const& stride,
                                   int batch_count)
        : type_(type),
          batch_stride_(size_t(0)),
          capacity_(size_t(0)),
          pointer_(nullptr),
          // Previously omitted, leaving layout_ uninitialized until reset()
          // assigned it; initialize it for consistency with the other ctors.
          layout_(library::LayoutTypeID::kUnknown),
          batch_count_(1) {
    reset(type, layout_id, extent, stride, batch_count);
}

/// Releases the device allocation, if any. The cudaFree result is ignored —
/// a destructor must not throw.
DeviceAllocation::~DeviceAllocation() {
    if (pointer_) {
        cudaFree(pointer_);
    }
}

/// Frees any device memory held and restores the empty, invalid state.
DeviceAllocation& DeviceAllocation::reset() {
    // Release device memory first; the remaining members are plain data.
    if (pointer_) {
        cudaFree(pointer_);
    }

    pointer_ = nullptr;
    type_ = library::NumericTypeID::kInvalid;
    layout_ = library::LayoutTypeID::kUnknown;
    batch_stride_ = 0;
    capacity_ = 0;
    batch_count_ = 1;
    extent_.clear();
    stride_.clear();
    tensor_ref_buffer_.clear();

    return *this;
}

/// Releases any existing allocation, then allocates device memory for
/// `capacity` elements of `type` as a single batch. Throws std::bad_alloc on
/// failure, leaving the object in the invalid (empty) state.
DeviceAllocation& DeviceAllocation::reset(library::NumericTypeID type,
                                          size_t capacity) {
    reset();

    type_ = type;
    batch_stride_ = capacity;
    capacity_ = capacity;

    cudaError_t result = cudaMalloc((void**)&pointer_, bytes(type_, capacity_));
    if (result != cudaSuccess) {
        // Match the allocating constructor: restore the fully invalid state
        // rather than throwing with type_/capacity_ half-initialized.
        type_ = library::NumericTypeID::kInvalid;
        batch_stride_ = 0;
        capacity_ = 0;
        pointer_ = nullptr;
        throw std::bad_alloc();
    }

    layout_ = library::LayoutTypeID::kUnknown;
    stride_.clear();
    extent_.clear();
    batch_count_ = 1;

    // The tensor_ref buffer begins with the raw device pointer.
    tensor_ref_buffer_.resize(sizeof(pointer_), 0);
    std::memcpy(tensor_ref_buffer_.data(), &pointer_, sizeof(pointer_));

    return *this;
}

/// Allocates memory for a given layout and tensor. An empty `stride` selects
/// the packed stride for `layout_id`. Throws std::bad_alloc if the device
/// allocation fails.
DeviceAllocation& DeviceAllocation::reset(library::NumericTypeID type,
                                          library::LayoutTypeID layout_id,
                                          std::vector<int> const& extent,
                                          std::vector<int> const& stride,
                                          int batch_count) {
    reset();

    // The tensor_ref buffer holds the device pointer followed by the layout's
    // stride vector (one int per stride rank).
    tensor_ref_buffer_.resize(
            sizeof(pointer_) +
                    (sizeof(int) * library::get_layout_stride_rank(layout_id)),
            0);

    type_ = type;

    layout_ = layout_id;
    stride_ = stride;
    extent_ = extent;
    batch_count_ = batch_count;

    // construct_layout writes the layout object into the buffer after the
    // pointer slot and returns the capacity (in elements) of one batch. It
    // may also overwrite stride_ with the packed stride when `stride` was
    // empty.
    batch_stride_ =
            construct_layout(tensor_ref_buffer_.data() + sizeof(pointer_),
                             layout_id, extent, stride_);

    capacity_ = batch_stride_ * batch_count_;

    cudaError_t result = cudaMalloc((void**)&pointer_, bytes(type, capacity_));
    if (result != cudaSuccess) {
        throw std::bad_alloc();
    }

    // Store the freshly allocated pointer at the front of the tensor_ref
    // buffer.
    std::memcpy(tensor_ref_buffer_.data(), &pointer_, sizeof(pointer_));

    return *this;
}

/// True only when device memory has been successfully allocated.
bool DeviceAllocation::good() const {
    return capacity_ != 0 && pointer_ != nullptr;
}

/// Gets the numeric element type of the allocation
library::NumericTypeID DeviceAllocation::type() const {
    return type_;
}

/// Gets the raw device pointer (nullptr when nothing is allocated)
void* DeviceAllocation::data() const {
    return pointer_;
}

/// Gets a pointer to the start of batch `batch_idx` within the allocation.
void* DeviceAllocation::batch_data(int batch_idx) const {
    // Advance from the base pointer by whole batches, measured in bytes.
    char* base = static_cast<char*>(data());
    return base + batch_idx * batch_stride_bytes();
}

/// Gets the layout type identifier of the allocation
library::LayoutTypeID DeviceAllocation::layout() const {
    return layout_;
}

/// Gets the stride vector (in units of elements)
std::vector<int> const& DeviceAllocation::stride() const {
    return stride_;
}

/// Gets the extent vector
std::vector<int> const& DeviceAllocation::extent() const {
    return extent_;
}

/// Gets the number of adjacent tensors in memory
int DeviceAllocation::batch_count() const {
    return batch_count_;
}

/// Gets the stride (in units of elements) between items
int64_t DeviceAllocation::batch_stride() const {
    return batch_stride_;
}

/// Gets the stride (in units of bytes) between items
int64_t DeviceAllocation::batch_stride_bytes() const {
    return bytes(type_, batch_stride_);
}

/// Gets the total capacity in elements (batch_stride * batch_count)
size_t DeviceAllocation::capacity() const {
    return capacity_;
}

/// Gets the total allocation size in bytes
size_t DeviceAllocation::bytes() const {
    return bytes(type_, capacity_);
}

/// Copies from an equivalent-sized tensor in device memory.
/// `ptr` must reference at least bytes() of readable device memory.
void DeviceAllocation::copy_from_device(void const* ptr) {
    cudaError_t result =
            cudaMemcpy(data(), ptr, bytes(), cudaMemcpyDeviceToDevice);
    if (result != cudaSuccess) {
        throw std::runtime_error("Failed device-to-device copy");
    }
}

/// Copies from an equivalent-sized tensor in host memory.
/// `ptr` must reference at least bytes() of readable host memory.
void DeviceAllocation::copy_from_host(void const* ptr) {
    cudaError_t result =
            cudaMemcpy(data(), ptr, bytes(), cudaMemcpyHostToDevice);
    if (result != cudaSuccess) {
        // Fixed: previously reported "device-to-device" for this
        // host-to-device transfer, obscuring the failure's direction.
        throw std::runtime_error("Failed host-to-device copy");
    }
}

/// Copies the allocation's contents to an equivalent-sized buffer in host
/// memory. `ptr` must reference at least bytes() of writable host memory.
void DeviceAllocation::copy_to_host(void* ptr) {
    cudaError_t result =
            cudaMemcpy(ptr, data(), bytes(), cudaMemcpyDeviceToHost);
    if (result != cudaSuccess) {
        // Fixed: previously reported "device-to-device" for this
        // device-to-host transfer, obscuring the failure's direction.
        throw std::runtime_error("Failed device-to-host copy");
    }
}

/// Fills the allocation on the device with random values drawn from `dist`,
/// dispatching on the numeric element type. Types without a case below are
/// silently left uninitialized.
void DeviceAllocation::initialize_random_device(int seed, Distribution dist) {
    if (!good()) {
        throw std::runtime_error(
                "Attempting to initialize invalid allocation.");
    }

    // Instantiate calls to CURAND here. This file takes a long time to compile
    // for this reason.

    switch (type_) {
        case library::NumericTypeID::kF16:
            cutlass::reference::device::BlockFillRandom<cutlass::half_t>(
                    reinterpret_cast<cutlass::half_t*>(pointer_), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kBF16:
            cutlass::reference::device::BlockFillRandom<cutlass::bfloat16_t>(
                    reinterpret_cast<cutlass::bfloat16_t*>(pointer_), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kTF32:
            cutlass::reference::device::BlockFillRandom<cutlass::tfloat32_t>(
                    reinterpret_cast<cutlass::tfloat32_t*>(pointer_), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kF32:
            cutlass::reference::device::BlockFillRandom<float>(
                    reinterpret_cast<float*>(pointer_), capacity_, seed, dist);
            break;
        case library::NumericTypeID::kCBF16:
            cutlass::reference::device::BlockFillRandom<complex<bfloat16_t>>(
                    reinterpret_cast<complex<bfloat16_t>*>(pointer_), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kCTF32:
            cutlass::reference::device::BlockFillRandom<
                    cutlass::complex<cutlass::tfloat32_t>>(
                    reinterpret_cast<cutlass::complex<cutlass::tfloat32_t>*>(
                            pointer_),
                    capacity_, seed, dist);
            break;
        case library::NumericTypeID::kCF32:
            cutlass::reference::device::BlockFillRandom<
                    cutlass::complex<float>>(
                    reinterpret_cast<cutlass::complex<float>*>(pointer_),
                    capacity_, seed, dist);
            break;
        case library::NumericTypeID::kF64:
            cutlass::reference::device::BlockFillRandom<double>(
                    reinterpret_cast<double*>(pointer_), capacity_, seed, dist);
            break;
        case library::NumericTypeID::kCF64:
            cutlass::reference::device::BlockFillRandom<complex<double>>(
                    reinterpret_cast<complex<double>*>(pointer_), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kS2:
            cutlass::reference::device::BlockFillRandom<int2b_t>(
                    reinterpret_cast<int2b_t*>(pointer_), capacity_, seed,
                    dist);
            break;
        case library::NumericTypeID::kS4:
            cutlass::reference::device::BlockFillRandom<int4b_t>(
                    reinterpret_cast<int4b_t*>(pointer_), capacity_, seed,
                    dist);
            break;
        case library::NumericTypeID::kS8:
            cutlass::reference::device::BlockFillRandom<int8_t>(
                    reinterpret_cast<int8_t*>(pointer_), capacity_, seed, dist);
            break;
        case library::NumericTypeID::kS16:
            cutlass::reference::device::BlockFillRandom<int16_t>(
                    reinterpret_cast<int16_t*>(pointer_), capacity_, seed,
                    dist);
            break;
        case library::NumericTypeID::kS32:
            cutlass::reference::device::BlockFillRandom<int32_t>(
                    reinterpret_cast<int32_t*>(pointer_), capacity_, seed,
                    dist);
            break;
        case library::NumericTypeID::kS64:
            cutlass::reference::device::BlockFillRandom<int64_t>(
                    reinterpret_cast<int64_t*>(pointer_), capacity_, seed,
                    dist);
            break;
        case library::NumericTypeID::kB1:
            cutlass::reference::device::BlockFillRandom<uint1b_t>(
                    reinterpret_cast<uint1b_t*>(pointer_), capacity_, seed,
                    dist);
            break;
        case library::NumericTypeID::kU2:
            cutlass::reference::device::BlockFillRandom<uint2b_t>(
                    reinterpret_cast<uint2b_t*>(pointer_), capacity_, seed,
                    dist);
            break;
        case library::NumericTypeID::kU4:
            cutlass::reference::device::BlockFillRandom<uint4b_t>(
                    reinterpret_cast<uint4b_t*>(pointer_), capacity_, seed,
                    dist);
            break;
        case library::NumericTypeID::kU8:
            cutlass::reference::device::BlockFillRandom<uint8_t>(
                    reinterpret_cast<uint8_t*>(pointer_), capacity_, seed,
                    dist);
            break;
        case library::NumericTypeID::kU16:
            cutlass::reference::device::BlockFillRandom<uint16_t>(
                    reinterpret_cast<uint16_t*>(pointer_), capacity_, seed,
                    dist);
            break;
        case library::NumericTypeID::kU32:
            cutlass::reference::device::BlockFillRandom<uint32_t>(
                    reinterpret_cast<uint32_t*>(pointer_), capacity_, seed,
                    dist);
            break;
        case library::NumericTypeID::kU64:
            cutlass::reference::device::BlockFillRandom<uint64_t>(
                    reinterpret_cast<uint64_t*>(pointer_), capacity_, seed,
                    dist);
            break;
        default:
            // Unsupported element types are silently skipped.
            break;
    }
}

/// Fills the allocation with random values generated on the host (then copied
/// to the device), dispatching on the numeric element type. Note this path
/// also supports kCF16, which the device path does not; unsupported types
/// result in copying zero-initialized host data.
void DeviceAllocation::initialize_random_host(int seed, Distribution dist) {
    if (!good()) {
        throw std::runtime_error(
                "Attempting to initialize invalid allocation.");
    }

    // Staging buffer on the host; zero-initialized by the vector.
    std::vector<uint8_t> host_data(bytes());

    switch (type_) {
        case library::NumericTypeID::kF16:
            cutlass::reference::host::BlockFillRandom<cutlass::half_t>(
                    reinterpret_cast<cutlass::half_t*>(host_data.data()),
                    capacity_, seed, dist);
            break;
        case library::NumericTypeID::kBF16:
            cutlass::reference::host::BlockFillRandom<cutlass::bfloat16_t>(
                    reinterpret_cast<cutlass::bfloat16_t*>(host_data.data()),
                    capacity_, seed, dist);
            break;
        case library::NumericTypeID::kTF32:
            cutlass::reference::host::BlockFillRandom<cutlass::tfloat32_t>(
                    reinterpret_cast<cutlass::tfloat32_t*>(host_data.data()),
                    capacity_, seed, dist);
            break;
        case library::NumericTypeID::kF32:
            cutlass::reference::host::BlockFillRandom<float>(
                    reinterpret_cast<float*>(host_data.data()), capacity_, seed,
                    dist);
            break;
        case library::NumericTypeID::kCF16:
            cutlass::reference::host::BlockFillRandom<
                    cutlass::complex<cutlass::half_t>>(
                    reinterpret_cast<cutlass::complex<cutlass::half_t>*>(
                            host_data.data()),
                    capacity_, seed, dist);
            break;
        case library::NumericTypeID::kCBF16:
            cutlass::reference::host::BlockFillRandom<
                    cutlass::complex<cutlass::bfloat16_t>>(
                    reinterpret_cast<cutlass::complex<cutlass::bfloat16_t>*>(
                            host_data.data()),
                    capacity_, seed, dist);
            break;
        case library::NumericTypeID::kCTF32:
            cutlass::reference::host::BlockFillRandom<
                    cutlass::complex<cutlass::tfloat32_t>>(
                    reinterpret_cast<cutlass::complex<cutlass::tfloat32_t>*>(
                            host_data.data()),
                    capacity_, seed, dist);
            break;
        case library::NumericTypeID::kCF32:
            cutlass::reference::host::BlockFillRandom<cutlass::complex<float>>(
                    reinterpret_cast<cutlass::complex<float>*>(
                            host_data.data()),
                    capacity_, seed, dist);
            break;
        case library::NumericTypeID::kF64:
            cutlass::reference::host::BlockFillRandom<double>(
                    reinterpret_cast<double*>(host_data.data()), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kCF64:
            cutlass::reference::host::BlockFillRandom<cutlass::complex<double>>(
                    reinterpret_cast<cutlass::complex<double>*>(
                            host_data.data()),
                    capacity_, seed, dist);
            break;
        case library::NumericTypeID::kS2:
            cutlass::reference::host::BlockFillRandom<int2b_t>(
                    reinterpret_cast<int2b_t*>(host_data.data()), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kS4:
            cutlass::reference::host::BlockFillRandom<int4b_t>(
                    reinterpret_cast<int4b_t*>(host_data.data()), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kS8:
            cutlass::reference::host::BlockFillRandom<int8_t>(
                    reinterpret_cast<int8_t*>(host_data.data()), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kS16:
            cutlass::reference::host::BlockFillRandom<int16_t>(
                    reinterpret_cast<int16_t*>(host_data.data()), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kS32:
            cutlass::reference::host::BlockFillRandom<int32_t>(
                    reinterpret_cast<int32_t*>(host_data.data()), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kS64:
            cutlass::reference::host::BlockFillRandom<int64_t>(
                    reinterpret_cast<int64_t*>(host_data.data()), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kB1:
            cutlass::reference::host::BlockFillRandom<uint1b_t>(
                    reinterpret_cast<uint1b_t*>(host_data.data()), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kU2:
            cutlass::reference::host::BlockFillRandom<uint2b_t>(
                    reinterpret_cast<uint2b_t*>(host_data.data()), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kU4:
            cutlass::reference::host::BlockFillRandom<uint4b_t>(
                    reinterpret_cast<uint4b_t*>(host_data.data()), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kU8:
            cutlass::reference::host::BlockFillRandom<uint8_t>(
                    reinterpret_cast<uint8_t*>(host_data.data()), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kU16:
            cutlass::reference::host::BlockFillRandom<uint16_t>(
                    reinterpret_cast<uint16_t*>(host_data.data()), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kU32:
            cutlass::reference::host::BlockFillRandom<uint32_t>(
                    reinterpret_cast<uint32_t*>(host_data.data()), capacity_,
                    seed, dist);
            break;
        case library::NumericTypeID::kU64:
            cutlass::reference::host::BlockFillRandom<uint64_t>(
                    reinterpret_cast<uint64_t*>(host_data.data()), capacity_,
                    seed, dist);
            break;
        default:
            // Unsupported element types fall through: zero-filled data is
            // copied to the device below.
            break;
    }

    copy_from_host(host_data.data());
}

/// Fills the allocation with random sparse-metadata patterns, generated on
/// the device.
/// NOTE(review): this device path dispatches on kU16/kU32 while the host
/// counterpart dispatches on kS16/kS32 — confirm which enumerants callers
/// use; other types are silently ignored.
void DeviceAllocation::initialize_random_sparsemeta_device(int seed,
                                                           int MetaSizeInBits) {
    if (!good()) {
        throw std::runtime_error(
                "Attempting to initialize invalid allocation.");
    }

    // Instantiate calls to CURAND here. This file takes a long time to compile
    // for this reason.

    switch (type_) {
        case library::NumericTypeID::kU16:
            cutlass::reference::device::BlockFillRandomSparseMeta<uint16_t>(
                    reinterpret_cast<uint16_t*>(pointer_), capacity_, seed,
                    MetaSizeInBits);
            break;
        case library::NumericTypeID::kU32:
            cutlass::reference::device::BlockFillRandomSparseMeta<uint32_t>(
                    reinterpret_cast<uint32_t*>(pointer_), capacity_, seed,
                    MetaSizeInBits);
            break;
        default:
            break;
    }
}

/// Initializes the allocation with random sparse-metadata bit patterns
/// generated on the host, then copies the result to the device.
///
/// \param seed            pseudo-random generator seed
/// \param MetaSizeInBits  number of metadata bits per element
///
/// Throws std::runtime_error if the allocation is invalid. For element types
/// other than 16-/32-bit integers the device buffer is overwritten with the
/// zero-initialized staging buffer.
void DeviceAllocation::initialize_random_sparsemeta_host(int seed,
                                                         int MetaSizeInBits) {
    if (!good()) {
        throw std::runtime_error(
                "Attempting to initialize invalid allocation.");
    }

    // Host-side staging buffer (value-initialized to zero).
    std::vector<uint8_t> host_data(bytes());

    switch (type_) {
        // Sparse metadata is a raw bit pattern; treat signed and unsigned
        // 16-/32-bit element types identically so that this path accepts the
        // same set of types as initialize_random_sparsemeta_device().
        case library::NumericTypeID::kS16:
        case library::NumericTypeID::kU16:
            cutlass::reference::host::BlockFillRandomSparseMeta<uint16_t>(
                    reinterpret_cast<uint16_t*>(host_data.data()), capacity_,
                    seed, MetaSizeInBits);
            break;
        case library::NumericTypeID::kS32:
        case library::NumericTypeID::kU32:
            cutlass::reference::host::BlockFillRandomSparseMeta<uint32_t>(
                    reinterpret_cast<uint32_t*>(host_data.data()), capacity_,
                    seed, MetaSizeInBits);
            break;
        default:
            break;
    }

    copy_from_host(host_data.data());
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Returns true if two blocks have exactly the same value
bool DeviceAllocation::block_compare_equal(library::NumericTypeID numeric_type,
                                           void const* ptr_A, void const* ptr_B,
                                           size_t capacity) {
    // Dispatch helper: reinterprets both blocks as arrays of the element type
    // carried by `tag` and compares them element-wise on the device.
    auto equal_as = [&](auto tag) -> bool {
        using Element = decltype(tag);
        return reference::device::BlockCompareEqual<Element>(
                static_cast<Element const*>(ptr_A),
                static_cast<Element const*>(ptr_B), capacity);
    };

    switch (numeric_type) {
        case library::NumericTypeID::kF16:
            return equal_as(half_t{});
        case library::NumericTypeID::kBF16:
            return equal_as(bfloat16_t{});
        case library::NumericTypeID::kTF32:
            return equal_as(tfloat32_t{});
        case library::NumericTypeID::kF32:
            return equal_as(float{});
        case library::NumericTypeID::kCF32:
            return equal_as(cutlass::complex<float>{});
        case library::NumericTypeID::kCF16:
            return equal_as(complex<half_t>{});
        case library::NumericTypeID::kCBF16:
            return equal_as(complex<bfloat16_t>{});
        case library::NumericTypeID::kCTF32:
            return equal_as(complex<tfloat32_t>{});
        case library::NumericTypeID::kF64:
            return equal_as(double{});
        case library::NumericTypeID::kCF64:
            return equal_as(complex<double>{});
        case library::NumericTypeID::kS2:
            return equal_as(int2b_t{});
        case library::NumericTypeID::kS4:
            return equal_as(int4b_t{});
        case library::NumericTypeID::kS8:
            return equal_as(int8_t{});
        case library::NumericTypeID::kS16:
            return equal_as(int16_t{});
        case library::NumericTypeID::kS32:
            return equal_as(int32_t{});
        case library::NumericTypeID::kS64:
            return equal_as(int64_t{});
        case library::NumericTypeID::kB1:
            return equal_as(uint1b_t{});
        case library::NumericTypeID::kU2:
            return equal_as(uint2b_t{});
        case library::NumericTypeID::kU4:
            return equal_as(uint4b_t{});
        case library::NumericTypeID::kU8:
            return equal_as(uint8_t{});
        case library::NumericTypeID::kU16:
            return equal_as(uint16_t{});
        case library::NumericTypeID::kU32:
            return equal_as(uint32_t{});
        case library::NumericTypeID::kU64:
            return equal_as(uint64_t{});
        default:
            throw std::runtime_error("Unsupported numeric type");
    }
}

/// Returns true if two blocks have approximately the same value
bool DeviceAllocation::block_compare_relatively_equal(
        library::NumericTypeID numeric_type, void const* ptr_A,
        void const* ptr_B, size_t capacity, double epsilon,
        double nonzero_floor) {
    // Element-wise relative comparison of the two device blocks, with
    // `epsilon` and `nonzero_floor` converted to the element type carried by
    // `tag`.
    auto relatively_equal_as = [&](auto tag) -> bool {
        using Element = decltype(tag);
        return reference::device::BlockCompareRelativelyEqual<Element>(
                static_cast<Element const*>(ptr_A),
                static_cast<Element const*>(ptr_B), capacity,
                static_cast<Element>(epsilon),
                static_cast<Element>(nonzero_floor));
    };

    // Exact (bitwise) comparison, used for complex element types below.
    auto exactly_equal_as = [&](auto tag) -> bool {
        using Element = decltype(tag);
        return reference::device::BlockCompareEqual<Element>(
                static_cast<Element const*>(ptr_A),
                static_cast<Element const*>(ptr_B), capacity);
    };

    switch (numeric_type) {
        case library::NumericTypeID::kF16:
            return relatively_equal_as(half_t{});
        case library::NumericTypeID::kBF16:
            return relatively_equal_as(bfloat16_t{});
        case library::NumericTypeID::kTF32:
            return relatively_equal_as(tfloat32_t{});
        case library::NumericTypeID::kF32:
            return relatively_equal_as(float{});
        case library::NumericTypeID::kF64:
            return relatively_equal_as(double{});
        case library::NumericTypeID::kS2:
            return relatively_equal_as(int2b_t{});
        case library::NumericTypeID::kS4:
            return relatively_equal_as(int4b_t{});
        case library::NumericTypeID::kS8:
            return relatively_equal_as(int8_t{});
        case library::NumericTypeID::kS16:
            return relatively_equal_as(int16_t{});
        case library::NumericTypeID::kS32:
            return relatively_equal_as(int32_t{});
        case library::NumericTypeID::kS64:
            return relatively_equal_as(int64_t{});
        case library::NumericTypeID::kB1:
            return relatively_equal_as(uint1b_t{});
        case library::NumericTypeID::kU2:
            return relatively_equal_as(uint2b_t{});
        case library::NumericTypeID::kU4:
            return relatively_equal_as(uint4b_t{});
        case library::NumericTypeID::kU8:
            return relatively_equal_as(uint8_t{});
        case library::NumericTypeID::kU16:
            return relatively_equal_as(uint16_t{});
        case library::NumericTypeID::kU32:
            return relatively_equal_as(uint32_t{});
        case library::NumericTypeID::kU64:
            return relatively_equal_as(uint64_t{});

        // No relatively equal comparison for complex numbers.
        //
        // As a simplification, we can require bitwise equality. This avoids
        // false positives. (i.e. "pass" really means passing. "Fail" may not
        // actually mean failure given appropriate epsilon.)
        //
        case library::NumericTypeID::kCF16:
            return exactly_equal_as(cutlass::complex<half_t>{});
        case library::NumericTypeID::kCF32:
            return exactly_equal_as(cutlass::complex<float>{});
        case library::NumericTypeID::kCF64:
            return exactly_equal_as(cutlass::complex<double>{});

        default: {
            throw std::runtime_error(std::string("Unsupported numeric type: ") +
                                     to_string(numeric_type));
        }
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Permits copying dynamic vectors into static-length vectors
template <typename TensorCoord, int Rank>
struct vector_to_coord {
    /// Copies vec[0..Rank-1] into coord, assigning indices from highest to
    /// lowest. Throws std::out_of_range (via vector::at) if vec is too short.
    vector_to_coord(TensorCoord& coord, std::vector<int> const& vec) {
        for (int idx = Rank - 1; idx >= 0; --idx) {
            coord[idx] = vec.at(idx);
        }
    }
};

/// Permits copying dynamic vectors into static-length vectors
///
/// Rank-1 terminal case of the recursion: copies the single remaining element.
/// Throws std::out_of_range (via vector::at) if vec is empty.
template <typename TensorCoord>
struct vector_to_coord<TensorCoord, 1> {
    vector_to_coord(TensorCoord& coord, std::vector<int> const& vec) {
        coord[0] = vec.at(0);
    }
};

/// Permits copying dynamic vectors into static-length vectors
///
/// Rank-0 case: nothing to copy.
template <typename TensorCoord>
struct vector_to_coord<TensorCoord, 0> {
    vector_to_coord(TensorCoord& coord, std::vector<int> const& vec) {}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Writes the contents of a device allocation to `out` in CSV form, using a
/// statically-typed TensorView with element type `Element` and layout
/// `Layout`.
///
/// Copies the device-resident data into a host staging tensor shaped by the
/// allocation's extent and stride, then streams it via TensorViewWrite.
///
/// Throws std::runtime_error if the allocation's extent or stride rank does
/// not match `Layout`, or if the host tensor capacity does not match the
/// allocation's batch stride.
template <typename Element, typename Layout>
static void write_tensor_csv_static_tensor_view(std::ostream& out,
                                                DeviceAllocation& allocation) {
    Coord<Layout::kRank> extent;
    Coord<Layout::kStrideRank> stride;

    if (allocation.extent().size() != Layout::kRank) {
        throw std::runtime_error("Allocation extent has invalid rank");
    }

    if (allocation.stride().size() != Layout::kStrideRank) {
        throw std::runtime_error("Allocation stride has invalid rank");
    }

    // Copy the dynamically-sized extent/stride vectors into static Coords.
    vector_to_coord<Coord<Layout::kRank>, Layout::kRank>(extent,
                                                         allocation.extent());
    vector_to_coord<Coord<Layout::kStrideRank>, Layout::kStrideRank>(
            stride, allocation.stride());

    Layout layout(stride);
    // NOTE(review): the trailing `false` presumably disables the HostTensor's
    // device-side backing allocation (only host storage is needed here) --
    // confirm against the HostTensor constructor.
    HostTensor<Element, Layout> host_tensor(extent, layout, false);

    if (host_tensor.capacity() != allocation.batch_stride()) {
        throw std::runtime_error(
                "Expected host tensor capacity to equal allocation batch "
                "stride.");
    }

    host_tensor.copy_in_device_to_host(
            static_cast<Element const*>(allocation.data()),
            allocation.batch_stride());

    TensorViewWrite(out, host_tensor.host_view());

    out << "\n\n";
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
static void write_tensor_csv_static_type(std::ostream& out,
                                         DeviceAllocation& allocation) {
    switch (allocation.layout()) {
        case library::LayoutTypeID::kRowMajor:
            write_tensor_csv_static_tensor_view<T, layout::RowMajor>(
                    out, allocation);
            break;
        case library::LayoutTypeID::kColumnMajor:
            write_tensor_csv_static_tensor_view<T, layout::ColumnMajor>(
                    out, allocation);
            break;
        case library::LayoutTypeID::kRowMajorInterleavedK2:
            write_tensor_csv_static_tensor_view<T,
                                                layout::RowMajorInterleaved<2>>(
                    out, allocation);
            break;
        case library::LayoutTypeID::kColumnMajorInterleavedK2:
            write_tensor_csv_static_tensor_view<
                    T, layout::ColumnMajorInterleaved<2>>(out, allocation);
            break;
        case library::LayoutTypeID::kRowMajorInterleavedK4:
            write_tensor_csv_static_tensor_view<T,
                                                layout::RowMajorInterleaved<4>>(
                    out, allocation);
            break;
        case library::LayoutTypeID::kColumnMajorInterleavedK4:
            write_tensor_csv_static_tensor_view<
                    T, layout::ColumnMajorInterleaved<4>>(out, allocation);
            break;
        case library::LayoutTypeID::kRowMajorInterleavedK16:
            write_tensor_csv_static_tensor_view<
                    T, layout::RowMajorInterleaved<16>>(out, allocation);
            break;
        case library::LayoutTypeID::kColumnMajorInterleavedK16:
            write_tensor_csv_static_tensor_view<
                    T, layout::ColumnMajorInterleaved<16>>(out, allocation);
            break;
        case library::LayoutTypeID::kRowMajorInterleavedK32:
            write_tensor_csv_static_tensor_view<
                    T, layout::RowMajorInterleaved<32>>(out, allocation);
            break;
        case library::LayoutTypeID::kColumnMajorInterleavedK32:
            write_tensor_csv_static_tensor_view<
                    T, layout::ColumnMajorInterleaved<32>>(out, allocation);
            break;
        case library::LayoutTypeID::kRowMajorInterleavedK64:
            write_tensor_csv_static_tensor_view<
                    T, layout::RowMajorInterleaved<64>>(out, allocation);
            break;
        case library::LayoutTypeID::kColumnMajorInterleavedK64:
            write_tensor_csv_static_tensor_view<
                    T, layout::ColumnMajorInterleaved<64>>(out, allocation);
            break;
        case library::LayoutTypeID::kTensorNHWC:
            write_tensor_csv_static_tensor_view<T, layout::TensorNHWC>(
                    out, allocation);
            break;
        case library::LayoutTypeID::kTensorNDHWC:
            write_tensor_csv_static_tensor_view<T, layout::TensorNDHWC>(
                    out, allocation);
            break;
        case library::LayoutTypeID::kTensorNC32HW32:
            write_tensor_csv_static_tensor_view<T, layout::TensorNCxHWx<32>>(
                    out, allocation);
            break;
        case library::LayoutTypeID::kTensorNC64HW64:
            write_tensor_csv_static_tensor_view<T, layout::TensorNCxHWx<64>>(
                    out, allocation);
            break;
        case library::LayoutTypeID::kTensorC32RSK32:
            write_tensor_csv_static_tensor_view<T, layout::TensorCxRSKx<32>>(
                    out, allocation);
            break;
        case library::LayoutTypeID::kTensorC64RSK64:
            write_tensor_csv_static_tensor_view<T, layout::TensorCxRSKx<64>>(
                    out, allocation);
            break;
        default:
            throw std::runtime_error("Unhandled layout");
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Writes a tensor to csv
void DeviceAllocation::write_tensor_csv(std::ostream& out) {
    switch (this->type()) {
        case library::NumericTypeID::kF16:
            write_tensor_csv_static_type<half_t>(out, *this);
            break;

        case library::NumericTypeID::kBF16:
            write_tensor_csv_static_type<bfloat16_t>(out, *this);
            break;

        case library::NumericTypeID::kTF32:
            write_tensor_csv_static_type<tfloat32_t>(out, *this);
            break;

        case library::NumericTypeID::kF32:
            write_tensor_csv_static_type<float>(out, *this);
            break;

        case library::NumericTypeID::kF64:
            write_tensor_csv_static_type<double>(out, *this);
            break;

        case library::NumericTypeID::kS2:
            write_tensor_csv_static_type<int2b_t>(out, *this);
            break;

        case library::NumericTypeID::kS4:
            write_tensor_csv_static_type<int4b_t>(out, *this);
            break;

        case library::NumericTypeID::kS8:
            write_tensor_csv_static_type<int8_t>(out, *this);
            break;

        case library::NumericTypeID::kS16:
            write_tensor_csv_static_type<int16_t>(out, *this);
            break;

        case library::NumericTypeID::kS32:
            write_tensor_csv_static_type<int32_t>(out, *this);
            break;

        case library::NumericTypeID::kS64:
            write_tensor_csv_static_type<int64_t>(out, *this);
            break;

        case library::NumericTypeID::kB1:
            write_tensor_csv_static_type<uint1b_t>(out, *this);
            break;

        case library::NumericTypeID::kU2:
            write_tensor_csv_static_type<uint2b_t>(out, *this);
            break;

        case library::NumericTypeID::kU4:
            write_tensor_csv_static_type<uint4b_t>(out, *this);
            break;

        case library::NumericTypeID::kU8:
            write_tensor_csv_static_type<uint8_t>(out, *this);
            break;

        case library::NumericTypeID::kU16:
            write_tensor_csv_static_type<uint16_t>(out, *this);
            break;

        case library::NumericTypeID::kU32:
            write_tensor_csv_static_type<uint32_t>(out, *this);
            break;

        case library::NumericTypeID::kU64:
            write_tensor_csv_static_type<uint64_t>(out, *this);
            break;

        case library::NumericTypeID::kCF16:
            write_tensor_csv_static_type<cutlass::complex<half_t>>(out, *this);
            break;

        case library::NumericTypeID::kCF32:
            write_tensor_csv_static_type<cutlass::complex<float>>(out, *this);
            break;

        case library::NumericTypeID::kCF64:
            write_tensor_csv_static_type<cutlass::complex<double>>(out, *this);
            break;

        default:
            throw std::runtime_error("Unsupported numeric type");
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace profiler
}  // namespace cutlass
